import os
import sys
import torch
import gzip
import itertools
import random
import numpy
import math
import json
import torchvision
from PIL import Image
from torch import nn
from matplotlib import pyplot
from collections import defaultdict
from functools import lru_cache

# Target size (width, height) every image is scaled to
IMAGE_SIZE = (80, 80)
# Paths of the datasets used for training
DATASET_1_DIR = "./dataset/lfwpeople/lfw_funneled"
DATASET_2_DIR = "./dataset/face-discern-dataset/Faces/Faces"
DATASET_3_DIR = "./dataset/105_classes_pins_dataset"
# How many times every sample is repeated within one training epoch
REPEAT_SAMPLES = 2
# Number of different-person (negative) samples compared against each anchor
NEGATIVE_SAMPLES = 10
# How many of those negative samples are picked at random
NEGATIVE_RANDOM_SAMPLES = 3
# Number of nearest faces to skip when picking hard negatives.
# This avoids the "twins" problem:
# if the model is fed many very similar faces (possibly mislabelled, possibly
# low-quality pictures, possibly genuinely look-alike people) and is told they
# are different people, it will later also reject unseen pictures of the same person.
# The FaceNet paper avoids this by searching for the nearest different person locally;
# here the globally nearest faces are computed and the top ones are skipped,
# which is workable while the dataset is small.
NEGATIVE_SKIP_NEAREST = 20
# Minimum number of images a person needs to be usable as a positive identity
MINIMAL_POSITIVE_SAMPLES = 2
# Whether images are converted to grayscale before processing
USE_GRAYSCALE = True

# Used to enable GPU support

DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"

device = torch.device(DEVICE_TYPE)

class FaceRecognitionModel(nn.Module):
    """Face recognition model that maps a face image to an embedding vector.

    The embedding is meant to be compared against other embeddings to find
    the closest face. The backbone is torchvision's ResNet-18 followed by a
    small fully connected head.

    Both static helpers expect ``predicted`` to be laid out in consecutive
    groups of ``2 + NEGATIVE_SAMPLES`` rows:
    ``[anchor, positive, negative, negative, ...]``.
    """
    # Length of the embedding vector
    EmbeddedSize = 32
    # Required gap between same-person and different-person distances
    # (distances are sums of squared differences)
    ExclusiveMargin = 0.2

    def __init__(self):
        super().__init__()
        # ResNet backbone
        self.resnet = torchvision.models.resnet18(num_classes=256)
        # Accept single-channel input when grayscale images are used
        if USE_GRAYSCALE:
            self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Head that produces the final embedding.
        # torchvision's resnet already ends with a Linear layer, so the head
        # starts with the activation instead of another Linear.
        self.encode_model = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, FaceRecognitionModel.EmbeddedSize))

    def forward(self, x):
        """Return the embeddings for a batch of images."""
        features = self.resnet(x)
        return self.encode_model(features)

    @staticmethod
    def loss_function(predicted):
        """Compute the triplet loss over all groups in ``predicted``.

        For every group the loss is
        ``sum(relu(d(anchor, positive) - d(anchor, negative_i) + ExclusiveMargin))``
        and the returned value is the mean over the groups.
        """
        group_size = 2 + NEGATIVE_SAMPLES
        margin = FaceRecognitionModel.ExclusiveMargin
        losses = []
        for index in range(0, predicted.shape[0], group_size):
            a = predicted[index]                     # anchor embedding
            b = predicted[index+1]                   # same person, other image
            c = predicted[index+2:index+group_size]  # different people
            # Squared distances between the embeddings
            diff_positive = (a - b).pow(2).sum()
            diff_negative = (a - c).pow(2).sum(dim=1)
            # Triplet loss: the positive distance must be smaller than every
            # negative distance by at least the margin
            losses.append(
                nn.functional.relu(diff_positive - diff_negative + margin).sum())
        return torch.stack(losses).mean()

    @staticmethod
    def calc_accuracy(predicted):
        """Return the fraction of groups whose positive is closer than every negative."""
        group_size = 2 + NEGATIVE_SAMPLES
        total_count = 0
        correct_count = 0
        for index in range(0, predicted.shape[0], group_size):
            a = predicted[index]                     # anchor embedding
            b = predicted[index+1]                   # same person, other image
            c = predicted[index+2:index+group_size]  # different people
            diff_positive = (a - b).pow(2).sum()
            diff_negative = (a - c).pow(2).sum(dim=1)
            # A group counts as correct only when the positive distance is
            # smaller than all of the negative distances
            if bool((diff_positive < diff_negative).all()):
                correct_count += 1
            total_count += 1
        return correct_count / total_count

class FaceVerificationModel(nn.Module):
    """Face verification model that decides whether two faces are the same person.

    The input is the element-wise squared difference of two embeddings.

    Both static helpers expect ``predicted`` to be laid out in consecutive
    groups of ``1 + NEGATIVE_SAMPLES`` values:
    ``[same person, different person, different person, ...]``.
    """
    # Threshold above which two faces are treated as the same person;
    # a higher value can be used in production to reduce false matches
    VerifyThreshold = 0.5

    def __init__(self):
        super().__init__()
        # Linear model judging whether both embeddings belong to one person
        self.verify_model = nn.Sequential(
            nn.Linear(FaceRecognitionModel.EmbeddedSize, 1),
            nn.Sigmoid())

    def forward(self, x):
        """Return one value in (0, 1) per input row, flattened to 1-D."""
        # After training the weights are expected to be negative and the
        # bias positive
        y = self.verify_model(x)
        return y.view(-1)

    @staticmethod
    def loss_function(predicted):
        """Return the average of the positive and negative MSE losses.

        The two losses are computed separately; otherwise the far more
        numerous negative samples would push the bias towards a negative value.
        """
        group_size = 1 + NEGATIVE_SAMPLES
        positive_indexes = list(range(0, predicted.shape[0], group_size))
        negative_indexes = [
            index + offset
            for index in positive_indexes
            for offset in range(1, group_size)]
        positive_predicted = predicted[positive_indexes]
        negative_predicted = predicted[negative_indexes]
        # Targets are created next to the predictions (same device and dtype)
        # instead of relying on a module-level device
        positive_loss = nn.functional.mse_loss(
            positive_predicted, torch.ones_like(positive_predicted))
        negative_loss = nn.functional.mse_loss(
            negative_predicted, torch.zeros_like(negative_predicted))
        return (positive_loss + negative_loss) / 2

    @staticmethod
    def calc_accuracy(predicted):
        """Return the mean of the positive accuracy and the negative accuracy."""
        threshold = FaceVerificationModel.VerifyThreshold
        positive_correct = 0
        positive_total = 0
        negative_correct = 0
        negative_total = 0
        for index in range(0, predicted.shape[0], 1+NEGATIVE_SAMPLES):
            positive_correct += (predicted[index] >= threshold).sum().item()
            negative_correct += (
                predicted[index+1:index+1+NEGATIVE_SAMPLES] < threshold).sum().item()
            positive_total += 1
            negative_total += NEGATIVE_SAMPLES
        # Negatives are the majority, so the two accuracies are averaged
        # rather than pooled
        return (positive_correct / positive_total + negative_correct / negative_total) / 2

def save_tensor(tensor, path):
    """Serialize ``tensor`` with ``torch.save`` into a gzip-compressed file at ``path``."""
    # The context manager closes the gzip stream deterministically, so the
    # gzip trailer is flushed before this function returns
    with gzip.GzipFile(path, "wb") as gzip_file:
        torch.save(tensor, gzip_file)

# Loaded tensors are cached to cut down on repeated reads.
# Lower maxsize if memory is tight.
@lru_cache(maxsize=10000)
def load_tensor(path):
    """Load an object written by ``torch.save`` from the gzip file at ``path``.

    Results are cached per ``path``: if the file is rewritten after the
    first load, later calls still return the previously loaded object.
    """
    # Close the gzip stream as soon as the object has been read
    with gzip.GzipFile(path, "rb") as gzip_file:
        return torch.load(gzip_file, map_location=DEVICE_TYPE)

def calc_resize_parameters(sw, sh, target_size=None):
    """Compute the padded canvas needed to bring an image to the target aspect ratio.

    Args:
        sw: source width.
        sh: source height.
        target_size: optional ``(width, height)`` to aim for; defaults to
            the module-level ``IMAGE_SIZE``.

    Returns:
        ``(canvas_width, canvas_height, pad_w, pad_h)`` where the canvas has
        the target aspect ratio and ``pad_w`` / ``pad_h`` are the left / top
        offsets at which the source image should be placed.
    """
    dw, dh = IMAGE_SIZE if target_size is None else target_size
    sw_new, sh_new = sw, sh
    pad_w, pad_h = 0, 0
    if sw / sh < dw / dh:
        # Source is too narrow: widen the canvas and pad left/right
        sw_new = int(dw / dh * sh)
        pad_w = (sw_new - sw) // 2
    else:
        # Source is too wide (or already matches): pad top/bottom
        sh_new = int(dh / dw * sw)
        pad_h = (sh_new - sh) // 2
    return sw_new, sh_new, pad_w, pad_h

def resize_image(img):
    """Scale ``img`` to ``IMAGE_SIZE``, padding first when the aspect ratio differs.

    The image is pasted onto a fresh RGB canvas (left at the default fill of
    ``Image.new``) that has the target aspect ratio, and the canvas is then
    resized.
    """
    canvas_w, canvas_h, offset_x, offset_y = calc_resize_parameters(*img.size)
    canvas = Image.new("RGB", (canvas_w, canvas_h))
    canvas.paste(img, (offset_x, offset_y))
    return canvas.resize(IMAGE_SIZE)

def image_to_tensor_grayscale(img):
    """Convert an image object to a single-channel float tensor (grayscale).

    No resizing happens here; the image is used at whatever size it has.
    The result has a leading channel dimension of size 1 and values scaled
    to the range 0 ~ 1.
    """
    img = img.convert("L")  # convert to grayscale
    # numpy.array makes a writable copy; from_numpy warns about read-only arrays
    arr = numpy.array(img)
    t = torch.from_numpy(arr)
    t = t.unsqueeze(0)  # add the channel dimension
    t = t / 255.0  # normalize values to the range 0 ~ 1
    return t

def image_to_tensor_rgb(img):
    """Convert an image object to a three-channel float tensor (color).

    No resizing happens here; the image is used at whatever size it has.
    The first and last array dimensions are swapped (H,W,C becomes C,W,H)
    and values are scaled to the range 0 ~ 1.
    """
    img = img.convert("RGB")  # make sure there are exactly three channels
    # numpy.array makes a writable copy; from_numpy warns about read-only arrays
    arr = numpy.array(img)
    t = torch.from_numpy(arr)
    t = t.transpose(0, 2)  # reorder dimensions from H,W,C to C,W,H
    t = t / 255.0  # normalize values to the range 0 ~ 1
    return t

# Pick the image-to-tensor converter that matches the configured color mode
image_to_tensor = image_to_tensor_grayscale if USE_GRAYSCALE else image_to_tensor_rgb

def prepare():
    """Prepare for training: convert the datasets into per-person tensor files.

    Scans the three dataset folders, groups image paths by normalized person
    name, writes every (optionally cropped) face to ``debug_faces`` for
    inspection, and saves one stacked tensor per person to
    ``data/<name>.<image count>.pt``.
    """
    # Converted tensors are stored under the data folder
    os.makedirs("data", exist_ok=True)

    # Cropped face pictures are stored under the debug_faces folder
    os.makedirs("debug_faces", exist_ok=True)

    # Collect people and their image paths
    # { person name: [ image path, image path, .. ] }
    images_map = defaultdict(list)
    def add_image(name, path):
        # Only .jpg and .png files are used
        if os.path.splitext(path)[1].lower() not in (".jpg", ".png"):
            return
        # Normalize the name so the same person matches across datasets;
        # this also strips "." which later separates the fields of the file name
        name = name.replace(" ", "").replace("-", "").replace(".", "").replace("_", "").lower()
        images_map[name].append(path)
    for dirname in os.listdir(DATASET_1_DIR):
        dirpath = os.path.join(DATASET_1_DIR, dirname)
        if not os.path.isdir(dirpath):
            continue
        for filename in os.listdir(dirpath):
            add_image(dirname, os.path.join(DATASET_1_DIR, dirname, filename))
    for filename in os.listdir(DATASET_2_DIR):
        # The person name is the part of the file name before the first "_"
        add_image(filename.split("_")[0], os.path.join(DATASET_2_DIR, filename))
    for dirname in os.listdir(DATASET_3_DIR):
        dirpath = os.path.join(DATASET_3_DIR, dirname)
        name = dirname.replace("pins_", "")
        if not os.path.isdir(dirpath):
            continue
        for filename in os.listdir(dirpath):
            add_image(name, os.path.join(DATASET_3_DIR, dirname, filename))
    images_count = sum(map(len, images_map.values()))
    print(f"found {len(images_map)} peoples and {images_count} images")

    # Save the image data of every person.
    # Images are not mirrored: face photos are normally not flipped, and some
    # facial features differ between the left and right side.
    img_index = 0
    for index, (name, paths) in enumerate(images_map.items()):
        tensors = []
        for path in paths:
            # The with-block closes the underlying file once the image has
            # been converted, so no PIL image needs to be kept around
            with Image.open(path) as source_img:
                img = source_img
                # Crop so the face takes up a similar share of the picture
                # in every dataset
                if path.startswith(DATASET_1_DIR):
                    w, h = img.size
                    img = img.crop((int(w*0.25), int(h*0.25), int(w*0.75), int(h*0.75)))
                elif path.startswith(DATASET_3_DIR):
                    w, h = img.size
                    img = img.crop((int(w*0.15), int(h*0.15), int(w*0.85), int(h*0.85)))
                # Save the cropped face so the crop range can be checked
                img.save(f"debug_faces/{img_index}.png")
                img_index += 1
                tensors.append(image_to_tensor(resize_image(img)))
        # Stacked along a new first dimension: one entry per image
        tensor = torch.stack(tensors)
        save_tensor(tensor, os.path.join("data", f"{name}.{len(tensors)}.pt"))
        print(f"saved {index+1}/{len(images_map)} peoples")

    print("done")

def train():
    """Train the face recognition model.

    Reads the per-person tensor files from the ``data`` folder, splits the
    people into training (80%), validating (10%) and testing (10%) sets,
    trains on triplet-style groups, saves the model to
    ``model.recognition.pt`` whenever the validating accuracy is within 1%
    of the best seen so far, and stops early when that has not happened for
    more than 20 epochs. Finally reports the testing accuracy and plots the
    accuracy histories.
    """
    # Model instance
    model = FaceRecognitionModel().to(device)

    # Loss function
    loss_function = model.loss_function

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Accuracy history of the training and validating sets
    training_accuracy_history = []
    validating_accuracy_history = []

    # Highest validating accuracy seen so far
    validating_accuracy_highest = -1
    validating_accuracy_highest_epoch = 0

    # Accuracy helper
    calc_accuracy = model.calc_accuracy

    # Read the list of people and separate those with enough images from
    # those without; the latter are only used as negative samples.
    # File names look like "<name>.<image count>.pt".
    filenames = os.listdir("data")
    multiple_samples = []
    single_samples = []
    for filename in filenames:
        if int(filename.split('.')[-2]) >= MINIMAL_POSITIVE_SAMPLES:
            multiple_samples.append(filename)
        else:
            single_samples.append(filename)
    random.shuffle(multiple_samples)
    random.shuffle(single_samples)
    total_multiple_samples = len(multiple_samples)
    total_single_samples = len(single_samples)

    # Split into training (80%), validating (10%) and testing (10%) sets
    training_set = multiple_samples[:int(total_multiple_samples*0.8)]
    training_set_single = single_samples[:int(total_single_samples*0.8)]
    validating_set = multiple_samples[int(total_multiple_samples*0.8):int(total_multiple_samples*0.9)]
    validating_set_single = single_samples[int(total_single_samples*0.8):int(total_single_samples*0.9)]
    testing_set = multiple_samples[int(total_multiple_samples*0.9):]
    testing_set_single = single_samples[int(total_single_samples*0.9):]

    # Embeddings of every training image, as of the last time the image was
    # part of a training batch
    training_image_to_vector_index = {}
    training_vector_index_to_image = {}
    for filename in training_set + training_set_single:
        for image_index in range(int(filename.split('.')[1])):
            vector_index = len(training_image_to_vector_index)
            training_image_to_vector_index[(filename, image_index)] = vector_index
            training_vector_index_to_image[vector_index] = (filename, image_index)
    training_vectors = torch.zeros(len(training_image_to_vector_index), FaceRecognitionModel.EmbeddedSize)
    training_vectors_calculated_indices = set()

    # Generate the model inputs.
    # Yields (images, vector indices) where images are laid out as
    # [ anchor, same person (positive), different people (negatives), ... ]
    # and vector indices (training set only) give the embedding slot of each image.
    def generate_inputs(dataset_multiple, dataset_single, batch_size):
        # Snapshot the embeddings computed so far (training set only)
        is_training = dataset_multiple is training_set
        if is_training:
            calculated_index_list = list(training_vectors_calculated_indices)
            calculated_index_set = set(calculated_index_list)
            # position inside the snapshot -> (file name, image index)
            calculated_index_to_image = {
                ci: training_vector_index_to_image[vi]
                for ci, vi in enumerate(calculated_index_list)
            }
            training_vectors_calculated = training_vectors[calculated_index_list]
        # Walk the dataset REPEAT_SAMPLES times to dampen the accuracy noise
        # caused by random selection
        image_tensors = []
        vector_indices = []
        for base_filename in dataset_multiple * REPEAT_SAMPLES:
            # Load the images of the base person
            base_tensor = load_tensor(os.path.join("data", base_filename))
            base_tensors = list(enumerate(base_tensor))
            # Shuffle, then take the images pairwise as anchor and positive
            random.shuffle(base_tensors)
            for index in range(0, len(base_tensors)-1, 2):
                # Add the anchor and the positive to the batch
                anchor_image_index, anchor_tensor = base_tensors[index]
                positive_image_index, positive_tensor = base_tensors[index+1]
                image_tensors.append(anchor_tensor)
                image_tensors.append(positive_tensor)
                # For the training set, rank the already computed embeddings
                # by their distance to the anchor's embedding
                nearest_indices = []
                if is_training:
                    anchor_vector_index = training_image_to_vector_index[(base_filename, anchor_image_index)]
                    vector_indices.append(anchor_vector_index)
                    vector_indices.append(training_image_to_vector_index[(base_filename, positive_image_index)])
                    if anchor_vector_index in calculated_index_set:
                        nearest_indices = ((training_vectors_calculated -
                                            training_vectors[anchor_vector_index]).abs().sum(dim=1).sort().indices).tolist()
                # Pick the negatives:
                # training set with a known anchor embedding -> nearest
                # embeddings plus some random samples; otherwise random only
                if is_training and nearest_indices:
                    negative_samples = NEGATIVE_SAMPLES - NEGATIVE_RANDOM_SAMPLES
                    negative_random_samples = NEGATIVE_RANDOM_SAMPLES
                else:
                    negative_samples = 0
                    negative_random_samples = NEGATIVE_SAMPLES
                negative_skip_nearest = NEGATIVE_SKIP_NEAREST
                for calculated_index in nearest_indices:
                    if negative_samples <= 0:
                        break
                    filename, image_index = calculated_index_to_image[calculated_index]
                    if filename == base_filename:
                        continue # skip the same person
                    if negative_skip_nearest > 0:
                        negative_skip_nearest -= 1
                        continue # skip faces that are extremely similar
                    target_tensor = load_tensor(os.path.join("data", filename))
                    # Add the negative to the batch (only reachable for the training set)
                    image_tensors.append(target_tensor[image_index])
                    vector_indices.append(training_image_to_vector_index[(filename, image_index)])
                    negative_samples -= 1
                # When there were not enough usable nearest embeddings (small
                # dataset or few embeddings computed so far), fill the
                # shortfall with random negatives so every group keeps its size
                negative_random_samples += negative_samples
                negative_samples = 0
                while negative_random_samples > 0:
                    file_index = random.randint(0, len(dataset_multiple) + len(dataset_single) - 1)
                    if file_index < len(dataset_multiple):
                        filename = dataset_multiple[file_index]
                    else:
                        filename = dataset_single[file_index - len(dataset_multiple)]
                    if filename == base_filename:
                        continue # skip the same person
                    target_tensor = load_tensor(os.path.join("data", filename))
                    image_index = random.randint(0, target_tensor.shape[0] - 1)
                    # Add the negative to the batch
                    image_tensors.append(target_tensor[image_index])
                    if is_training:
                        vector_indices.append(training_image_to_vector_index[(filename, image_index)])
                    negative_random_samples -= 1
                # Emit a batch once enough images have been collected.
                # The lists are reused, so the caller must finish with
                # vector_indices before asking for the next batch.
                if len(image_tensors) >= batch_size:
                    yield torch.stack(image_tensors).to(device), vector_indices
                    image_tensors.clear()
                    vector_indices.clear()
        if image_tensors:
            yield torch.stack(image_tensors).to(device), vector_indices

    # Training loop
    for epoch in range(0, 200):
        print(f"epoch: {epoch}")

        # Train on the training set and adjust the parameters
        # Switch the model to training mode
        model.train()
        training_accuracy_list = []
        for index, (batch_x, vector_indices) in enumerate(
                generate_inputs(training_set, training_set_single, 400)):
            # Forward pass
            predicted = model(batch_x)
            # Loss
            loss = loss_function(predicted)
            # Backpropagate
            loss.backward()
            # Update the parameters
            optimizer.step()
            # Reset the gradients
            optimizer.zero_grad()
            # Remember the embedding of every image in the batch; detach and
            # copy the whole batch to the cpu once instead of row by row
            predicted_cpu = predicted.detach().to("cpu")
            for vector_index, vector in zip(vector_indices, predicted_cpu):
                training_vectors[vector_index] = vector
                training_vectors_calculated_indices.add(vector_index)
            # Accuracy of this batch; no_grad temporarily disables autograd
            with torch.no_grad():
                training_batch_accuracy = calc_accuracy(predicted)
            # Report the batch accuracy
            training_accuracy_list.append(training_batch_accuracy)
            print(f"epoch: {epoch}, batch: {index}, accuracy: {training_batch_accuracy}")
        training_accuracy = sum(training_accuracy_list) / len(training_accuracy_list)
        training_accuracy_history.append(training_accuracy)
        print(f"training accuracy: {training_accuracy}")

        # Check the validating set
        # Switch the model to evaluation mode
        model.eval()
        validating_accuracy_list = []
        # No gradients are needed here; disabling autograd avoids keeping the
        # computation graph of every batch in (GPU) memory
        with torch.no_grad():
            for batch_x, _ in generate_inputs(validating_set, validating_set_single, 100):
                predicted = model(batch_x)
                validating_batch_accuracy = calc_accuracy(predicted)
                validating_accuracy_list.append(validating_batch_accuracy)
                # Release the memory held by predicted
                predicted = None
        validating_accuracy = sum(validating_accuracy_list) / len(validating_accuracy_list)
        validating_accuracy_history.append(validating_accuracy)
        print(f"validating accuracy: {validating_accuracy}")

        # Track the highest validating accuracy and the model state at that
        # time, and check whether the record has gone unrefreshed for many epochs.
        # Negatives of the validating set are random, so a 1% fluctuation is
        # tolerated to let the model train for longer.
        if (validating_accuracy + 0.01) > validating_accuracy_highest:
            if validating_accuracy > validating_accuracy_highest:
                validating_accuracy_highest = validating_accuracy
                print("highest validating accuracy updated")
            else:
                print("highest validating accuracy not dropped")
            validating_accuracy_highest_epoch = epoch
            save_tensor(model.state_dict(), "model.recognition.pt")
        elif epoch - validating_accuracy_highest_epoch > 20:
            # The record was not refreshed for more than 20 epochs: stop
            print("stop training because validating accuracy dropped from highest in 20 epoches")
            break

    # Use the model state from the last saved checkpoint
    print(f"highest validating accuracy: {validating_accuracy_highest}",
          f"from epoch {validating_accuracy_highest_epoch}")
    model.load_state_dict(load_tensor("model.recognition.pt"))

    # Check the testing set
    model.eval()
    testing_accuracy_list = []
    with torch.no_grad():
        for batch_x, _ in generate_inputs(testing_set, testing_set_single, 100):
            predicted = model(batch_x)
            testing_batch_accuracy = calc_accuracy(predicted)
            testing_accuracy_list.append(testing_batch_accuracy)
    testing_accuracy = sum(testing_accuracy_list) / len(testing_accuracy_list)
    print(f"testing accuracy: {testing_accuracy}")

    # Plot the accuracy history of the training and validating sets
    pyplot.plot(training_accuracy_history, label="training_accuracy")
    pyplot.plot(validating_accuracy_history, label="validating_accuracy")
    pyplot.ylim(0, 1)
    pyplot.legend()
    pyplot.show()

def train_verify():
    """Train the face verification model.

    Loads the trained recognition model from ``model.recognition.pt``, uses
    it to turn images into embeddings, and trains the verification model on
    the squared differences between embeddings. The model is saved to
    ``model.verification.pt`` whenever the validating accuracy is within 1%
    of the best seen so far. Finally reports the testing accuracy and plots
    the accuracy histories.
    """
    # Recognition model with the trained parameters, used for inference only
    recognize_model = FaceRecognitionModel().to(device)
    recognize_model.load_state_dict(load_tensor("model.recognition.pt"))
    recognize_model.eval()

    # Verification model instance
    model = FaceVerificationModel().to(device)

    # Loss function
    loss_function = model.loss_function

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Accuracy history of the training and validating sets
    training_accuracy_history = []
    validating_accuracy_history = []

    # Highest validating accuracy seen so far
    validating_accuracy_highest = -1
    validating_accuracy_highest_epoch = 0

    # Accuracy helper
    calc_accuracy = model.calc_accuracy

    # Read the list of people and separate those with enough images from
    # those without; the latter are only used as negative samples.
    # File names look like "<name>.<image count>.pt".
    filenames = os.listdir("data")
    multiple_samples = []
    single_samples = []
    for filename in filenames:
        if int(filename.split('.')[-2]) >= MINIMAL_POSITIVE_SAMPLES:
            multiple_samples.append(filename)
        else:
            single_samples.append(filename)
    random.seed(123) # make the order differ from the one used to train the recognition model
    random.shuffle(multiple_samples)
    random.shuffle(single_samples)
    total_multiple_samples = len(multiple_samples)
    total_single_samples = len(single_samples)

    # Split into training (80%), validating (10%) and testing (10%) sets
    training_set = multiple_samples[:int(total_multiple_samples*0.8)]
    training_set_single = single_samples[:int(total_single_samples*0.8)]
    validating_set = multiple_samples[int(total_multiple_samples*0.8):int(total_multiple_samples*0.9)]
    validating_set_single = single_samples[int(total_single_samples*0.8):int(total_single_samples*0.9)]
    testing_set = multiple_samples[int(total_multiple_samples*0.9):]
    testing_set_single = single_samples[int(total_single_samples*0.9):]

    # Embedding cache { (file name, image index): embedding }
    vector_cache = {}

    # Return the embedding of an image, computing it on first use
    def get_vector(filename, image_index, image_tensor):
        key = (filename, image_index)
        vector = vector_cache.get(key)
        if vector is None:
            with torch.no_grad():
                vector = recognize_model(image_tensor.unsqueeze(0).to(device))[0].to("cpu")
            vector_cache[key] = vector
        return vector

    # Generate the model inputs.
    # Yields batches laid out as
    # [ same-person embedding difference, different-person differences, ... ]
    def generate_inputs(dataset_multiple, dataset_single, batch_size):
        # Walk the dataset REPEAT_SAMPLES times to dampen the accuracy noise
        # caused by random selection
        diff_tensors = []
        for base_filename in dataset_multiple * REPEAT_SAMPLES:
            # Load the images of the base person
            base_tensor = load_tensor(os.path.join("data", base_filename))
            base_tensors = list(enumerate(base_tensor))
            # Shuffle, then take the images pairwise as anchor and positive
            random.shuffle(base_tensors)
            for index in range(0, len(base_tensors)-1, 2):
                # Squared embedding difference between anchor and positive
                anchor_image_index, anchor_tensor = base_tensors[index]
                positive_image_index, positive_tensor = base_tensors[index+1]
                anchor_vector = get_vector(base_filename, anchor_image_index, anchor_tensor)
                positive_vector = get_vector(base_filename, positive_image_index, positive_tensor)
                diff_tensors.append((anchor_vector - positive_vector).pow(2))
                # Pick random negatives and add their squared differences
                negative_random_samples = NEGATIVE_SAMPLES
                while negative_random_samples > 0:
                    file_index = random.randint(0, len(dataset_multiple) + len(dataset_single) - 1)
                    if file_index < len(dataset_multiple):
                        filename = dataset_multiple[file_index]
                    else:
                        filename = dataset_single[file_index - len(dataset_multiple)]
                    if filename == base_filename:
                        continue # skip the same person
                    target_tensor = load_tensor(os.path.join("data", filename))
                    image_index = random.randint(0, target_tensor.shape[0] - 1)
                    negative_vector = get_vector(filename, image_index, target_tensor[image_index])
                    diff_tensors.append((anchor_vector - negative_vector).pow(2))
                    negative_random_samples -= 1
                # Emit a batch once enough differences have been collected
                if len(diff_tensors) >= batch_size:
                    yield torch.stack(diff_tensors).to(device)
                    diff_tensors.clear()
        if diff_tensors:
            yield torch.stack(diff_tensors).to(device)

    # Training loop
    for epoch in range(1, 20):
        print(f"epoch: {epoch}")

        # Train on the training set and adjust the parameters
        # Switch the model to training mode
        model.train()
        training_accuracy_list = []
        for index, batch_x in enumerate(
                generate_inputs(training_set, training_set_single, 400)):
            # Forward pass
            predicted = model(batch_x)
            # Loss
            loss = loss_function(predicted)
            # Backpropagate
            loss.backward()
            # Update the parameters
            optimizer.step()
            # Reset the gradients
            optimizer.zero_grad()
            # Accuracy of this batch; no_grad temporarily disables autograd
            with torch.no_grad():
                training_batch_accuracy = calc_accuracy(predicted)
            # Report the batch accuracy
            training_accuracy_list.append(training_batch_accuracy)
            print(f"epoch: {epoch}, batch: {index}, accuracy: {training_batch_accuracy}")
        training_accuracy = sum(training_accuracy_list) / len(training_accuracy_list)
        training_accuracy_history.append(training_accuracy)
        print(f"training accuracy: {training_accuracy}")

        # Check the validating set
        # Switch the model to evaluation mode
        model.eval()
        validating_accuracy_list = []
        # No gradients are needed here; disabling autograd avoids keeping the
        # computation graph of every batch in (GPU) memory
        with torch.no_grad():
            for batch_x in generate_inputs(validating_set, validating_set_single, 100):
                predicted = model(batch_x)
                validating_batch_accuracy = calc_accuracy(predicted)
                validating_accuracy_list.append(validating_batch_accuracy)
                # Release the memory held by predicted
                predicted = None
        validating_accuracy = sum(validating_accuracy_list) / len(validating_accuracy_list)
        validating_accuracy_history.append(validating_accuracy)
        print(f"validating accuracy: {validating_accuracy}")

        # Track the highest validating accuracy and the model state at that
        # time, and check whether the record has gone unrefreshed for many epochs.
        # Negatives of the validating set are random, so a 1% fluctuation is
        # tolerated to let the model train for longer.
        if (validating_accuracy + 0.01) > validating_accuracy_highest:
            if validating_accuracy > validating_accuracy_highest:
                validating_accuracy_highest = validating_accuracy
                print("highest validating accuracy updated")
            else:
                print("highest validating accuracy not dropped")
            validating_accuracy_highest_epoch = epoch
            save_tensor(model.state_dict(), "model.verification.pt")
        elif epoch - validating_accuracy_highest_epoch > 20:
            # The record was not refreshed for more than 20 epochs: stop
            print("stop training because validating accuracy dropped from highest in 20 epoches")
            break

    # Use the model state from the last saved checkpoint
    print(f"highest validating accuracy: {validating_accuracy_highest}",
          f"from epoch {validating_accuracy_highest_epoch}")
    model.load_state_dict(load_tensor("model.verification.pt"))

    # Check the testing set
    model.eval()
    testing_accuracy_list = []
    with torch.no_grad():
        for batch_x in generate_inputs(testing_set, testing_set_single, 100):
            predicted = model(batch_x)
            testing_batch_accuracy = calc_accuracy(predicted)
            testing_accuracy_list.append(testing_batch_accuracy)
    testing_accuracy = sum(testing_accuracy_list) / len(testing_accuracy_list)
    print(f"testing accuracy: {testing_accuracy}")

    # Plot the accuracy history of the training and validating sets
    pyplot.plot(training_accuracy_history, label="training_accuracy")
    pyplot.plot(validating_accuracy_history, label="validating_accuracy")
    pyplot.ylim(0, 1)
    pyplot.legend()
    pyplot.show()

def main():
    """Entry point: run the operation named by the first command line argument.

    Supported operations are ``prepare``, ``train`` and ``train-verify``.

    Raises:
        SystemExit: with status 1 when no operation is given.
        ValueError: when the operation is not supported.
    """
    if len(sys.argv) < 2:
        print(f"Please run: {sys.argv[0]} prepare|train|train-verify")
        # sys.exit is always available, unlike the exit() helper that the
        # site module installs for interactive use
        sys.exit(1)

    # Seed the random number generators so every run produces the same
    # random numbers. This makes the process reproducible; it is optional.
    random.seed(0)
    torch.random.manual_seed(0)

    # Dispatch on the command line argument
    operation = sys.argv[1]
    if operation == "prepare":
        prepare()
    elif operation == "train":
        train()
    elif operation == "train-verify":
        train_verify()
    else:
        raise ValueError(f"Unsupported operation: {operation}")

if __name__ == "__main__":
    main()
